b99c65
@@ -51,7 +51,7 @@
import org.apache.hadoop.util.ToolRunner;
  * {@link TotalOrderPartitioner}.
  *
  * This is an identical copy of o.a.h.mapreduce.lib.partition.TotalOrderPartitioner
- * from Hadoop trunk at r910774, with the exception of replacing
+ * from Hadoop trunk at r961542, with the exception of replacing
  * TaskAttemptContextImpl with TaskAttemptContext.
  */
 public class InputSampler<K,V> extends Configured implements Tool  {
@@ -63,7 +63,7 @@
public class InputSampler<K,V> extends Configured implements Tool  {
       "      [-inFormat <input format class>]\n" +
       "      [-keyClass <map input & output key class>]\n" +
       "      [-splitRandom <double pcnt> <numSamples> <maxsplits> | " +
-      "// Sample from random splits at random (general)\n" +
+      "             // Sample from random splits at random (general)\n" +
       "       -splitSample <numSamples> <maxsplits> | " +
       "             // Sample from first records in splits (random data)\n"+
       "       -splitInterval <double pcnt> <maxsplits>]" +
@@ -129,16 +129,17 @@
public class InputSampler<K,V> extends Configured implements Tool  {
       List<InputSplit> splits = inf.getSplits(job);
       ArrayList<K> samples = new ArrayList<K>(numSamples);
       int splitsToSample = Math.min(maxSplitsSampled, splits.size());
-      int splitStep = splits.size() / splitsToSample;
       int samplesPerSplit = numSamples / splitsToSample;
       long records = 0;
       for (int i = 0; i < splitsToSample; ++i) {
+        TaskAttemptContext samplingContext = new TaskAttemptContext(
+            job.getConfiguration(), new TaskAttemptID());
         RecordReader<K,V> reader = inf.createRecordReader(
-          splits.get(i * splitStep), 
-          new TaskAttemptContext(job.getConfiguration(), 
-                                 new TaskAttemptID()));
+            splits.get(i), samplingContext);
+        reader.initialize(splits.get(i), samplingContext);
         while (reader.nextKeyValue()) {
-          samples.add(reader.getCurrentKey());
+          samples.add(ReflectionUtils.copy(job.getConfiguration(),
+                                           reader.getCurrentKey(), null));
           ++records;
           if ((i+1) * samplesPerSplit <= records) {
             break;
@@ -213,13 +214,16 @@
public class InputSampler<K,V> extends Configured implements Tool  {
       // the target sample keyset
       for (int i = 0; i < splitsToSample ||
                      (i < splits.size() && samples.size() < numSamples); ++i) {
-        RecordReader<K,V> reader = inf.createRecordReader(splits.get(i), 
-          new TaskAttemptContext(job.getConfiguration(), 
-                                 new TaskAttemptID()));
+        TaskAttemptContext samplingContext = new TaskAttemptContext(
+            job.getConfiguration(), new TaskAttemptID());
+        RecordReader<K,V> reader = inf.createRecordReader(
+            splits.get(i), samplingContext);
+        reader.initialize(splits.get(i), samplingContext);
         while (reader.nextKeyValue()) {
           if (r.nextDouble() <= freq) {
             if (samples.size() < numSamples) {
-              samples.add(reader.getCurrentKey());
+              samples.add(ReflectionUtils.copy(job.getConfiguration(),
+                                               reader.getCurrentKey(), null));
             } else {
               // When exceeding the maximum number of samples, replace a
               // random element with this one, then adjust the frequency
@@ -227,7 +231,8 @@
public class InputSampler<K,V> extends Configured implements Tool  {
               // pushed out
               int ind = r.nextInt(numSamples);
               if (ind != numSamples) {
-                samples.set(ind, reader.getCurrentKey());
+                samples.set(ind, ReflectionUtils.copy(job.getConfiguration(),
+                                 reader.getCurrentKey(), null));
               }
               freq *= (numSamples - 1) / (double) numSamples;
             }
@@ -277,19 +282,20 @@
public class InputSampler<K,V> extends Configured implements Tool  {
       List<InputSplit> splits = inf.getSplits(job);
       ArrayList<K> samples = new ArrayList<K>();
       int splitsToSample = Math.min(maxSplitsSampled, splits.size());
-      int splitStep = splits.size() / splitsToSample;
       long records = 0;
       long kept = 0;
       for (int i = 0; i < splitsToSample; ++i) {
+        TaskAttemptContext samplingContext = new TaskAttemptContext(
+            job.getConfiguration(), new TaskAttemptID());
         RecordReader<K,V> reader = inf.createRecordReader(
-          splits.get(i * splitStep),
-          new TaskAttemptContext(job.getConfiguration(), 
-                                 new TaskAttemptID()));
+            splits.get(i), samplingContext);
+        reader.initialize(splits.get(i), samplingContext);
         while (reader.nextKeyValue()) {
           ++records;
           if ((double) kept / records < freq) {
+            samples.add(ReflectionUtils.copy(job.getConfiguration(),
+                                 reader.getCurrentKey(), null));
             ++kept;
-            samples.add(reader.getCurrentKey());
           }
         }
         reader.close();
